package jwt

import (
	"crypto/rsa"
	"errors"
	"gitee.com/gopher2011/gin"
	"github.com/dgrijalva/jwt-go"
	"io/ioutil"
	"net/http"
	"strings"
	"time"
)

var (
	// ErrMissingSecretKey indicates that a secret key is required.
	ErrMissingSecretKey = errors.New("需要密钥")

	// ErrForbidden is used when HTTP status 403 is given.
	ErrForbidden = errors.New("您无权访问此资源")

	// ErrMissingAuthFunc indicates that JWTMiddleware.AuthFunc is required.
	ErrMissingAuthFunc = errors.New("JWTMiddleware.AuthFunc函数未定义!")

	// ErrMissingLoginValues indicates a user tried to authenticate without username or password.
	ErrMissingLoginValues = errors.New("缺少用户名或密码")

	// ErrFailedAuthentication indicates authentication failed; the username or password may be wrong.
	ErrFailedAuthentication = errors.New("用户名或密码错误")

	// ErrFailedTokenCreation indicates the JWT token could not be created.
	ErrFailedTokenCreation = errors.New("无法创建JWT Token")

	// ErrExpiredToken indicates the JWT token has expired and can no longer be refreshed.
	ErrExpiredToken = errors.New("token已过期")

	// ErrEmptyAuthHeader is returned when authenticating via an HTTP header and that header is empty.
	ErrEmptyAuthHeader = errors.New("auth标头为空")

	// ErrMissingExpField indicates the token has no exp field.
	ErrMissingExpField = errors.New("缺少exp字段")

	// ErrWrongFormatOfExp indicates the exp field is not a float64.
	ErrWrongFormatOfExp = errors.New("exp必须为float64格式")

	// ErrInvalidAuthHeader indicates the auth header is invalid, e.g. it uses the wrong scheme name.
	ErrInvalidAuthHeader = errors.New("auth标头无效")

	// ErrEmptyQueryToken is returned when authenticating via the URL query and the token parameter is empty.
	ErrEmptyQueryToken = errors.New("查询token为空")

	// ErrEmptyCookieToken is returned when authenticating via a cookie and the token cookie is empty.
	ErrEmptyCookieToken = errors.New("Cookie token为空")

	// ErrEmptyParamToken is returned when authenticating via a path parameter and that parameter is empty.
	ErrEmptyParamToken = errors.New("参数token为空")

	// ErrInvalidSigningAlgorithm indicates the signing algorithm is invalid; it must be
	// HS256, HS384, HS512, RS256, RS384 or RS512.
	ErrInvalidSigningAlgorithm = errors.New("无效签名算法")

	// ErrNoPrivateKeyFile indicates that the given private key file is unreadable.
	ErrNoPrivateKeyFile = errors.New("私钥文件不可读")

	// ErrNoPubKeyFile indicates that the given public key file is unreadable.
	ErrNoPubKeyFile = errors.New("公钥文件不可读")

	// ErrInvalidPrivateKey indicates that the given private key is invalid.
	ErrInvalidPrivateKey = errors.New("私钥无效")

	// ErrInvalidPubKey indicates that the given public key is invalid.
	ErrInvalidPubKey = errors.New("公钥无效")

	// IdentityKey is the default identity key.
	IdentityKey = "identity"
)

// MapClaims is the default claims type: the JWT payload decoded as a
// map[string]interface{}.
type MapClaims map[string]interface{}

// JWTMiddleware provides JSON Web Token authentication for gin handlers.
//
// On failure the middleware responds through UnAuthFunc (401 for a missing,
// invalid or expired token; see middlewareImpl for the other status codes).
// On success it stores the claims under "JWT_PAYLOAD", stores the identity
// returned by IdentityHandler under IdentityKey (retrievable with
// c.Get(IdentityKey)), and calls the next handler.
// Clients obtain a token from LoginHandler and then send it back; with the
// default TokenLookup it goes in the Authorization header, e.g.:
//
//	Authorization: Bearer XXX_TOKEN_XXX
type JWTMiddleware struct {
	// Realm is the realm name sent in the WWW-Authenticate header on failures.
	// Init sets it to "gin jwt" when empty.
	Realm string

	// SigningAlgorithm (optional) is the signing algorithm: HS256, HS384,
	// HS512, RS256, RS384 or RS512. Init defaults it to HS256.
	SigningAlgorithm string

	// Key is the secret used to sign and verify tokens with HS* algorithms.
	// Init returns ErrMissingSecretKey when it is not provided for a non-RS*
	// algorithm.
	Key []byte

	// Timeout (optional) is how long a token is valid. Init defaults it to one
	// hour.
	Timeout time.Duration

	// MaxRefresh (optional) lets clients refresh a token while its original
	// issue time ("orig_iat") is within MaxRefresh of the current time. A
	// given token can therefore be exchanged for a new one up to MaxRefresh
	// after it was issued, even if it has already expired; each refresh issues
	// a token with a fresh orig_iat. The default 0 effectively disables
	// refreshing.
	MaxRefresh time.Duration

	// AuthFunc (required) authenticates the login request. The returned data
	// is passed to PayloadFunc to build the token claims. LoginHandler responds
	// 500 (ErrMissingAuthFunc) when it is nil and 401 when it returns an error.
	AuthFunc func(c *gin.Context) (interface{}, error)

	// AuthAfter (optional) authorizes an authenticated request. The middleware
	// calls it with the identity from IdentityHandler after the token has been
	// validated; returning false yields a 403 (ErrForbidden). Init defaults it
	// to always returning true.
	AuthAfter func(data interface{}, c *gin.Context) bool

	// PayloadFunc (optional) is called with AuthFunc's data during login (and
	// by TokenGenerate) to add claims to the token. Those claims are available
	// on later requests via ExtractClaims / c.Get("JWT_PAYLOAD").
	// Note that the payload is not encrypted. The "exp" and "orig_iat" keys are
	// overwritten by the middleware; avoid other registered JWT claim names too.
	PayloadFunc func(data interface{}) MapClaims

	// UnAuthFunc writes the response for failed requests, given the HTTP status
	// code and message. Init provides a JSON default.
	UnAuthFunc func(*gin.Context, int, string)

	// LoginResponse writes the response for a successful login, given the
	// status code, token and expiry. Init provides a JSON default.
	LoginResponse func(*gin.Context, int, string, time.Time)

	// LogoutResponse writes the response for a logout, given the status code.
	// Init provides a JSON default.
	LogoutResponse func(*gin.Context, int)

	// RefreshResponse writes the response for a successful refresh, given the
	// status code, token and expiry. Init provides a JSON default.
	RefreshResponse func(*gin.Context, int, string, time.Time)

	// IdentityHandler extracts the identity from the request. Init's default
	// returns the claim named IdentityKey.
	IdentityHandler func(*gin.Context) interface{}

	// IdentityKey is the claim name read by the default IdentityHandler and the
	// context key the identity is stored under. Init defaults it to "identity".
	IdentityKey string

	// TokenLookup (optional) is a comma-separated list of "<source>:<name>"
	// entries used to extract the token from the request; they are tried in
	// order. Init defaults it to "header:Authorization".
	// Possible sources:
	// - "header:<name>"
	// - "query:<name>"
	// - "cookie:<name>"
	// - "param:<name>"
	TokenLookup string

	// TokenHeadName is the scheme expected before the token in the header.
	// Surrounding whitespace is trimmed; Init defaults it to "Bearer".
	TokenHeadName string

	// TimeFunc provides the current time. Override it for tests or when the
	// server runs in a different time zone than the tokens. Init defaults it
	// to time.Now.
	TimeFunc func() time.Time

	// HTTPStatusMsgFunc maps an error from the middleware to the message passed
	// to UnAuthFunc. Init's default returns e.Error().
	HTTPStatusMsgFunc func(e error, c *gin.Context) string

	// PrivateKeyFile is the path of the PEM private key file for RS* algorithms.
	PrivateKeyFile string

	// PrivateKeyByte holds the PEM private key for RS* algorithms.
	// If PrivateKeyFile is also set, PrivateKeyFile takes precedence.
	PrivateKeyByte []byte

	// PubKeyFile is the path of the PEM public key file for RS* algorithms.
	PubKeyFile string

	// PubKeyByte holds the PEM public key for RS* algorithms.
	// If PubKeyFile is also set, PubKeyFile takes precedence.
	PubKeyByte []byte

	// privateKey is the parsed private key, loaded by Init for RS* algorithms.
	privateKey *rsa.PrivateKey

	// pubKey is the parsed public key, loaded by Init for RS* algorithms.
	pubKey *rsa.PublicKey

	// SendCookie (optional) also returns the token as a cookie on login and
	// refresh; LogoutHandler then clears it.
	SendCookie bool

	// CookieMaxAge (optional) is how long the cookie is valid. Init defaults it
	// to Timeout.
	CookieMaxAge time.Duration

	// SecureCookie is passed as the secure flag of c.SetCookie (presumably the
	// cookie's Secure attribute, restricting it to HTTPS). Leave it false to
	// allow the cookie over plain HTTP during development.
	SecureCookie bool

	// CookieHTTPOnly is passed as the httpOnly flag of c.SetCookie (presumably
	// the HttpOnly attribute, hiding the cookie from client-side scripts).
	// Leave it false only if client-side code needs to read the cookie, e.g.
	// during development.
	CookieHTTPOnly bool

	// CookieDomain is the cookie's domain; it can be changed for development.
	CookieDomain string

	// SendAuthorization, when true, makes the middleware echo the request's
	// token back in the Authorization response header (when the token string
	// was recorded in the context under "JWT_TOKEN").
	SendAuthorization bool

	// DisabledAbort, when true, stops failures from calling c.Abort().
	DisabledAbort bool

	// CookieName is the cookie's name. Init defaults it to "jwt".
	CookieName string

	// CookieSameSite, when non-zero, is applied with c.SetSameSite before the
	// cookie is set.
	CookieSameSite http.SameSite
}

// ExtractClaims returns the JWT claims the middleware stored in the context
// under "JWT_PAYLOAD".
//
// It returns an empty, non-nil MapClaims when nothing was stored, or when the
// stored value is not a claims map. Previously an unexpected type caused a
// panic. A plain map[string]interface{} is accepted as well.
func ExtractClaims(c *gin.Context) MapClaims {
	value, exists := c.Get("JWT_PAYLOAD")
	if !exists {
		return make(MapClaims)
	}

	switch claims := value.(type) {
	case MapClaims:
		return claims
	case map[string]interface{}:
		return MapClaims(claims)
	default:
		return make(MapClaims)
	}
}

// usingPublicKeyAlgo reports whether SigningAlgorithm is one of the RSA
// algorithms (RS256, RS384, RS512), which sign with privateKey and verify
// with pubKey instead of the shared Key.
func (mw *JWTMiddleware) usingPublicKeyAlgo() bool {
	alg := mw.SigningAlgorithm
	return alg == "RS256" || alg == "RS384" || alg == "RS512"
}

// priKey loads the RSA private key used to sign RS* tokens and stores it in
// mw.privateKey. PrivateKeyFile takes precedence over PrivateKeyByte.
//
// It returns ErrNoPrivateKeyFile if the file cannot be read, and
// ErrInvalidPrivateKey if the data is not a PEM-encoded RSA private key.
func (mw *JWTMiddleware) priKey() error {
	pemData := mw.PrivateKeyByte
	if mw.PrivateKeyFile != "" {
		raw, readErr := ioutil.ReadFile(mw.PrivateKeyFile)
		if readErr != nil {
			return ErrNoPrivateKeyFile
		}
		pemData = raw
	}

	parsed, parseErr := jwt.ParseRSAPrivateKeyFromPEM(pemData)
	if parseErr != nil {
		return ErrInvalidPrivateKey
	}
	mw.privateKey = parsed
	return nil
}

// publicKey loads the RSA public key used to verify RS* tokens and stores it
// in mw.pubKey. PubKeyFile takes precedence over PubKeyByte.
//
// It returns ErrNoPubKeyFile if the file cannot be read, and ErrInvalidPubKey
// if the data is not a PEM-encoded RSA public key.
func (mw *JWTMiddleware) publicKey() error {
	pemData := mw.PubKeyByte
	if mw.PubKeyFile != "" {
		raw, readErr := ioutil.ReadFile(mw.PubKeyFile)
		if readErr != nil {
			return ErrNoPubKeyFile
		}
		pemData = raw
	}

	parsed, parseErr := jwt.ParseRSAPublicKeyFromPEM(pemData)
	if parseErr != nil {
		return ErrInvalidPubKey
	}
	mw.pubKey = parsed
	return nil
}

// readKey loads the RSA private key and then the public key, and returns the
// first error encountered.
func (mw *JWTMiddleware) readKey() error {
	if err := mw.priKey(); err != nil {
		return err
	}
	return mw.publicKey()
}

// Init validates the middleware configuration and fills in a default for
// every optional field left at its zero value. Fields such as TimeFunc and
// the response callbacks remain nil until Init runs, so call it (directly or
// through New) before using the middleware.
//
// It returns an error in these cases:
//   - ErrInvalidSigningAlgorithm when SigningAlgorithm is not HS256/384/512 or
//     RS256/384/512.
//   - For RS* algorithms, the key-loading errors (ErrNoPrivateKeyFile,
//     ErrInvalidPrivateKey, ErrNoPubKeyFile, ErrInvalidPubKey).
//   - ErrMissingSecretKey when an HS* algorithm is configured with an empty Key.
func (mw *JWTMiddleware) Init() error {

	if mw.TokenLookup == "" {
		mw.TokenLookup = "header:Authorization"
	}

	if mw.SigningAlgorithm == "" {
		mw.SigningAlgorithm = "HS256"
	}

	if mw.Timeout == 0 {
		mw.Timeout = time.Hour
	}

	if mw.TimeFunc == nil {
		mw.TimeFunc = time.Now
	}

	mw.TokenHeadName = strings.TrimSpace(mw.TokenHeadName)
	if len(mw.TokenHeadName) == 0 {
		mw.TokenHeadName = "Bearer"
	}

	if mw.AuthAfter == nil {
		mw.AuthAfter = func(data interface{}, c *gin.Context) bool {
			return true
		}
	}

	if mw.UnAuthFunc == nil {
		mw.UnAuthFunc = func(c *gin.Context, code int, message string) {
			c.JSON(code, gin.H{
				"code":    code,
				"message": message,
			})
		}
	}

	if mw.LoginResponse == nil {
		mw.LoginResponse = func(c *gin.Context, code int, token string, expire time.Time) {
			c.JSON(http.StatusOK, gin.H{
				"code":   http.StatusOK,
				"token":  token,
				"expire": expire.Format(time.RFC3339),
			})
		}
	}

	if mw.LogoutResponse == nil {
		mw.LogoutResponse = func(c *gin.Context, code int) {
			c.JSON(http.StatusOK, gin.H{
				"code": http.StatusOK,
			})
		}
	}

	if mw.RefreshResponse == nil {
		mw.RefreshResponse = func(c *gin.Context, code int, token string, expire time.Time) {
			c.JSON(http.StatusOK, gin.H{
				"code":   http.StatusOK,
				"token":  token,
				"expire": expire.Format(time.RFC3339),
			})
		}
	}

	if mw.IdentityKey == "" {
		mw.IdentityKey = IdentityKey
	}

	if mw.IdentityHandler == nil {
		mw.IdentityHandler = func(c *gin.Context) interface{} {
			claims := ExtractClaims(c)
			return claims[mw.IdentityKey]
		}
	}

	if mw.HTTPStatusMsgFunc == nil {
		mw.HTTPStatusMsgFunc = func(e error, c *gin.Context) string {
			return e.Error()
		}
	}

	if mw.Realm == "" {
		mw.Realm = "gin jwt"
	}

	if mw.CookieMaxAge == 0 {
		mw.CookieMaxAge = mw.Timeout
	}

	if mw.CookieName == "" {
		mw.CookieName = "jwt"
	}

	// Only these algorithms are wired up. HS* sign and verify with Key; RS*
	// use the parsed key pair. Any other value would otherwise surface later
	// as a panic (unknown name) or as every token failing to sign or verify.
	switch mw.SigningAlgorithm {
	case "HS256", "HS384", "HS512", "RS256", "RS384", "RS512":
	default:
		return ErrInvalidSigningAlgorithm
	}

	if mw.usingPublicKeyAlgo() {
		return mw.readKey()
	}

	// An empty (even if non-nil) secret makes HMAC signatures forgeable.
	if len(mw.Key) == 0 {
		return ErrMissingSecretKey
	}
	return nil
}
// New runs Init on m and returns m, or returns nil and the error from Init.
// m must be non-nil, because Init dereferences it.
func New(m *JWTMiddleware) (*JWTMiddleware, error) {
	if initErr := m.Init(); initErr != nil {
		return nil, initErr
	}
	return m, nil
}

// jwtFromHeader reads the token from the request header named key. The header
// must have the form "<TokenHeadName> <token>".
//
// The scheme is compared case-insensitively, because RFC 7235 defines
// authentication schemes that way. Whitespace around the token is trimmed.
//
// It returns ErrEmptyAuthHeader when the header is absent or empty, and
// ErrInvalidAuthHeader when the scheme does not match or no token follows it.
func (mw *JWTMiddleware) jwtFromHeader(c *gin.Context, key string) (string, error) {
	authHeader := c.Request.Header.Get(key)
	if authHeader == "" {
		return "", ErrEmptyAuthHeader
	}

	parts := strings.SplitN(authHeader, " ", 2)
	if len(parts) != 2 || !strings.EqualFold(parts[0], mw.TokenHeadName) {
		return "", ErrInvalidAuthHeader
	}

	token := strings.TrimSpace(parts[1])
	if token == "" {
		return "", ErrInvalidAuthHeader
	}
	return token, nil
}

// jwtFromQuery returns the token carried in the URL query parameter named key,
// or ErrEmptyQueryToken when that parameter is absent or empty.
func (mw *JWTMiddleware) jwtFromQuery(c *gin.Context, key string) (string, error) {
	if token := c.Query(key); token != "" {
		return token, nil
	}
	return "", ErrEmptyQueryToken
}

// jwtFromCookie returns the token stored in the cookie named key, or
// ErrEmptyCookieToken when no non-empty value is found.
func (mw *JWTMiddleware) jwtFromCookie(c *gin.Context, key string) (string, error) {
	// The lookup error is deliberately ignored. Whatever the cause, an empty
	// value is reported uniformly as ErrEmptyCookieToken below.
	value, _ := c.Cookie(key)
	if value != "" {
		return value, nil
	}
	return "", ErrEmptyCookieToken
}

// jwtFromParam returns the token taken from the route path parameter named
// key, or ErrEmptyParamToken when that parameter is empty.
func (mw *JWTMiddleware) jwtFromParam(c *gin.Context, key string) (string, error) {
	if token := c.Param(key); token != "" {
		return token, nil
	}
	return "", ErrEmptyParamToken
}

// ParseToken extracts the raw token from the request according to
// TokenLookup, then parses and verifies it with the configured algorithm and
// key.
//
// TokenLookup is a comma-separated list of "<source>:<name>" entries (header,
// query, cookie, param). They are tried in order and the first non-empty token
// wins. Entries without a ":" (for example from a trailing comma) or with an
// unknown source are skipped. If no source yields a token, the error reported
// by the last recognized source is returned.
//
// The raw token string is stored in the context under "JWT_TOKEN" once its
// signing method matches SigningAlgorithm, for both HMAC and RSA algorithms.
// This happens before the signature and claims are fully validated, so only
// trust it when ParseToken returns a nil error.
func (mw *JWTMiddleware) ParseToken(c *gin.Context) (*jwt.Token, error) {
	var token string
	var err error

	for _, method := range strings.Split(mw.TokenLookup, ",") {
		if len(token) > 0 {
			break
		}
		parts := strings.Split(strings.TrimSpace(method), ":")
		if len(parts) < 2 {
			// Malformed entry; indexing parts[1] would panic.
			continue
		}
		k := strings.TrimSpace(parts[0])
		v := strings.TrimSpace(parts[1])
		switch k {
		case "header":
			token, err = mw.jwtFromHeader(c, v)
		case "query":
			token, err = mw.jwtFromQuery(c, v)
		case "cookie":
			token, err = mw.jwtFromCookie(c, v)
		case "param":
			token, err = mw.jwtFromParam(c, v)
		}
	}

	if err != nil {
		return nil, err
	}

	return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
		if jwt.GetSigningMethod(mw.SigningAlgorithm) != t.Method {
			return nil, ErrInvalidSigningAlgorithm
		}

		// Record the raw token for SendAuthorization / GetToken. This is done
		// for every supported algorithm, not only HMAC.
		c.Set("JWT_TOKEN", token)

		if mw.usingPublicKeyAlgo() {
			return mw.pubKey, nil
		}
		return mw.Key, nil
	})
}

// GetClaimsFromJWT parses the request's token and returns a copy of its
// claims. When SendAuthorization is set and the raw token was recorded under
// "JWT_TOKEN", it is echoed back in the Authorization response header.
func (mw *JWTMiddleware) GetClaimsFromJWT(c *gin.Context) (MapClaims, error) {
	token, err := mw.ParseToken(c)
	if err != nil {
		return nil, err
	}

	if mw.SendAuthorization {
		if raw, found := c.Get("JWT_TOKEN"); found {
			c.Header("Authorization", mw.TokenHeadName+" "+raw.(string))
		}
	}

	source := token.Claims.(jwt.MapClaims)
	copied := make(MapClaims, len(source))
	for name, val := range source {
		copied[name] = val
	}
	return copied, nil
}

// unauthorized reports a failed request. It does three things:
//   - advertises the JWT scheme and realm in the WWW-Authenticate header;
//   - aborts the handler chain, unless DisabledAbort is set;
//   - delegates the response body to UnAuthFunc with the given status code
//     and message.
func (mw *JWTMiddleware) unauthorized(c *gin.Context, code int, message string) {
	// RFC 7235 requires the realm auth-param to be a token or a quoted-string.
	// Realms containing spaces (such as the "gin jwt" default) are not tokens,
	// so always quote the realm and escape backslashes and double quotes.
	realm := strings.NewReplacer(`\`, `\\`, `"`, `\"`).Replace(mw.Realm)
	c.Header("WWW-Authenticate", `JWT realm="`+realm+`"`)
	if !mw.DisabledAbort {
		c.Abort()
	}
	mw.UnAuthFunc(c, code, message)
}

// middlewareImpl authenticates a single request, then continues the handler
// chain. In order, it:
//   - parses and verifies the token (401 on failure);
//   - requires a float64 "exp" claim (400 if it is missing or malformed) that
//     is not earlier than TimeFunc() (401 if it has expired);
//   - stores the claims under "JWT_PAYLOAD" and any non-nil identity under
//     IdentityKey;
//   - asks AuthAfter to authorize the identity (403 on refusal).
func (mw *JWTMiddleware) middlewareImpl(c *gin.Context) {
	claims, err := mw.GetClaimsFromJWT(c)
	if err != nil {
		mw.unauthorized(c, http.StatusUnauthorized, mw.HTTPStatusMsgFunc(err, c))
		return
	}

	rawExp := claims["exp"]
	if rawExp == nil {
		mw.unauthorized(c, http.StatusBadRequest, mw.HTTPStatusMsgFunc(ErrMissingExpField, c))
		return
	}

	expSeconds, isFloat := rawExp.(float64)
	if !isFloat {
		mw.unauthorized(c, http.StatusBadRequest, mw.HTTPStatusMsgFunc(ErrWrongFormatOfExp, c))
		return
	}

	if int64(expSeconds) < mw.TimeFunc().Unix() {
		mw.unauthorized(c, http.StatusUnauthorized, mw.HTTPStatusMsgFunc(ErrExpiredToken, c))
		return
	}

	c.Set("JWT_PAYLOAD", claims)

	identity := mw.IdentityHandler(c)
	if identity != nil {
		c.Set(mw.IdentityKey, identity)
	}

	if mw.AuthAfter(identity, c) {
		c.Next()
		return
	}
	mw.unauthorized(c, http.StatusForbidden, mw.HTTPStatusMsgFunc(ErrForbidden, c))
}
// MiddlewareFunc returns the gin handler that enforces JWT authentication on
// every request it is attached to.
func (mw *JWTMiddleware) MiddlewareFunc() gin.HandlerFunc {
	return mw.middlewareImpl
}

// signedString signs token. RS* algorithms use the RSA private key; all
// others use the shared secret Key.
func (mw *JWTMiddleware) signedString(token *jwt.Token) (string, error) {
	var signingKey interface{} = mw.Key
	if mw.usingPublicKeyAlgo() {
		signingKey = mw.privateKey
	}
	return token.SignedString(signingKey)
}

// LoginHandler authenticates a client via AuthFunc and, on success, issues a
// signed JWT. How credentials are read from the request is entirely up to
// AuthFunc.
//
// Claims come from PayloadFunc(data), if set. Then "exp" (now+Timeout) and
// "orig_iat" (now) are added, overwriting any values PayloadFunc supplied for
// those keys. When SendCookie is set, the token is also written as a cookie
// that lives for CookieMaxAge.
//
// The response is produced by LoginResponse. With Init's default it is
// {"code": 200, "token": "...", "expire": "<RFC3339>"}.
//
// Failures are reported through unauthorized: 500 when AuthFunc is nil, and
// 401 when AuthFunc fails or the token cannot be signed.
func (mw *JWTMiddleware) LoginHandler(c *gin.Context) {
	if mw.AuthFunc == nil {
		mw.unauthorized(c, http.StatusInternalServerError, mw.HTTPStatusMsgFunc(ErrMissingAuthFunc, c))
		return
	}

	data, err := mw.AuthFunc(c)
	if err != nil {
		mw.unauthorized(c, http.StatusUnauthorized, mw.HTTPStatusMsgFunc(err, c))
		return
	}

	token := jwt.New(jwt.GetSigningMethod(mw.SigningAlgorithm))
	claims := token.Claims.(jwt.MapClaims)
	if mw.PayloadFunc != nil {
		for key, value := range mw.PayloadFunc(data) {
			claims[key] = value
		}
	}

	// Read the clock once so that "exp" and "orig_iat" agree with each other.
	now := mw.TimeFunc()
	expire := now.Add(mw.Timeout)
	claims["exp"] = expire.Unix()
	claims["orig_iat"] = now.Unix()

	tokenString, err := mw.signedString(token)
	if err != nil {
		mw.unauthorized(c, http.StatusUnauthorized, mw.HTTPStatusMsgFunc(ErrFailedTokenCreation, c))
		return
	}

	if mw.SendCookie {
		if mw.CookieSameSite != 0 {
			c.SetSameSite(mw.CookieSameSite)
		}

		c.SetCookie(
			mw.CookieName,
			tokenString,
			// Max-Age in seconds, taken directly from the configured duration.
			int(mw.CookieMaxAge/time.Second),
			"/",
			mw.CookieDomain,
			mw.SecureCookie,
			mw.CookieHTTPOnly,
		)
	}

	mw.LoginResponse(c, http.StatusOK, tokenString, expire)
}
// LogoutHandler clears the JWT cookie when SendCookie is enabled, by setting
// an empty value with a negative max age. It then responds via
// LogoutResponse with HTTP 200. The token itself is not revoked.
func (mw *JWTMiddleware) LogoutHandler(c *gin.Context) {
	if !mw.SendCookie {
		mw.LogoutResponse(c, http.StatusOK)
		return
	}

	if mw.CookieSameSite != 0 {
		c.SetSameSite(mw.CookieSameSite)
	}
	c.SetCookie(mw.CookieName, "", -1, "/", mw.CookieDomain, mw.SecureCookie, mw.CookieHTTPOnly)

	mw.LogoutResponse(c, http.StatusOK)
}
// CheckIfTokenExpire parses the request's token and decides whether it may
// still be refreshed.
//
// A token whose only validation failure is expiry is accepted. Any other
// parse or validation error is returned unchanged. The token is refreshable
// while its "orig_iat" claim is no more than MaxRefresh before TimeFunc().
// It returns ErrExpiredToken when the token is outside that window, or when
// "orig_iat" is missing or not a float64, since the refresh window then
// cannot be established. Previously a missing "orig_iat" caused a panic.
func (mw *JWTMiddleware) CheckIfTokenExpire(c *gin.Context) (jwt.MapClaims, error) {
	token, err := mw.ParseToken(c)
	if err != nil {
		// Return any error other than a lone ValidationErrorExpired. An
		// expired-only token can still be refreshed within MaxRefresh
		// (see https://github.com/appleboy/gin-jwt/issues/176).
		validationErr, ok := err.(*jwt.ValidationError)
		if !ok || validationErr.Errors != jwt.ValidationErrorExpired {
			return nil, err
		}
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return nil, ErrExpiredToken
	}

	origIatValue, ok := claims["orig_iat"].(float64)
	if !ok {
		return nil, ErrExpiredToken
	}

	if int64(origIatValue) < mw.TimeFunc().Add(-mw.MaxRefresh).Unix() {
		return nil, ErrExpiredToken
	}

	return claims, nil
}
// RefreshToken issues a new token in place of the request's current one. The
// current token may be expired, but must still be refreshable according to
// CheckIfTokenExpire.
//
// All existing claims are copied, then "exp" is reset to now+Timeout and
// "orig_iat" to now, where now is a single TimeFunc() reading. When
// SendCookie is set, the new token is also written as a cookie that lives
// for CookieMaxAge.
//
// It returns the signed token and its expiry. On error the returned time is
// time.Now() and should be ignored.
func (mw *JWTMiddleware) RefreshToken(c *gin.Context) (string, time.Time, error) {
	claims, err := mw.CheckIfTokenExpire(c)
	if err != nil {
		return "", time.Now(), err
	}

	newToken := jwt.New(jwt.GetSigningMethod(mw.SigningAlgorithm))
	newClaims := newToken.Claims.(jwt.MapClaims)
	for key, value := range claims {
		newClaims[key] = value
	}

	now := mw.TimeFunc()
	expire := now.Add(mw.Timeout)
	newClaims["exp"] = expire.Unix()
	newClaims["orig_iat"] = now.Unix()

	tokenString, err := mw.signedString(newToken)
	if err != nil {
		return "", time.Now(), err
	}

	if mw.SendCookie {
		// Derive Max-Age from the configured duration instead of subtracting
		// time.Now() from a TimeFunc()-based instant. With a custom TimeFunc,
		// that subtraction could produce a wrong or even negative max age,
		// and a negative max age deletes the cookie.
		maxage := int(mw.CookieMaxAge / time.Second)

		if mw.CookieSameSite != 0 {
			c.SetSameSite(mw.CookieSameSite)
		}

		c.SetCookie(
			mw.CookieName,
			tokenString,
			maxage,
			"/",
			mw.CookieDomain,
			mw.SecureCookie,
			mw.CookieHTTPOnly,
		)
	}

	return tokenString, expire, nil
}
// RefreshHandler refreshes the request's token via RefreshToken and replies
// through RefreshResponse. With Init's default, the reply is
// {"code": 200, "token": "...", "expire": "<RFC3339>"}. Failures are reported
// as 401 through unauthorized.
//
// An expired token is accepted only within MaxRefresh of its orig_iat. If
// this handler is mounted behind the middleware itself, expired tokens are
// rejected by the middleware before they reach it.
func (mw *JWTMiddleware) RefreshHandler(c *gin.Context) {
	tokenString, expire, err := mw.RefreshToken(c)
	if err == nil {
		mw.RefreshResponse(c, http.StatusOK, tokenString, expire)
		return
	}
	mw.unauthorized(c, http.StatusUnauthorized, mw.HTTPStatusMsgFunc(err, c))
}

// TokenGenerate creates a signed token for data outside the login flow.
//
// Claims come from PayloadFunc(data), if set, plus "exp" (now+Timeout) and
// "orig_iat" (now), where now is a single TimeFunc() reading. It returns the
// signed token and its expiry in UTC. On error the time is the zero value.
func (mw *JWTMiddleware) TokenGenerate(data interface{}) (string, time.Time, error) {
	token := jwt.New(jwt.GetSigningMethod(mw.SigningAlgorithm))
	claims := token.Claims.(jwt.MapClaims)

	if mw.PayloadFunc != nil {
		for key, value := range mw.PayloadFunc(data) {
			claims[key] = value
		}
	}

	// Read the clock once so that "exp" and "orig_iat" agree with each other.
	now := mw.TimeFunc()
	expire := now.UTC().Add(mw.Timeout)
	claims["exp"] = expire.Unix()
	claims["orig_iat"] = now.Unix()

	tokenString, err := mw.signedString(token)
	if err != nil {
		return "", time.Time{}, err
	}
	return tokenString, expire, nil
}

// ParseTokenStr parses and verifies a raw JWT string with the configured
// algorithm and key. Unlike ParseToken, it has no request context, so nothing
// is recorded under "JWT_TOKEN". Tokens signed with a different algorithm are
// rejected with ErrInvalidSigningAlgorithm.
func (mw *JWTMiddleware) ParseTokenStr(token string) (*jwt.Token, error) {
	keyFunc := func(t *jwt.Token) (interface{}, error) {
		if t.Method != jwt.GetSigningMethod(mw.SigningAlgorithm) {
			return nil, ErrInvalidSigningAlgorithm
		}
		if !mw.usingPublicKeyAlgo() {
			return mw.Key, nil
		}
		return mw.pubKey, nil
	}
	return jwt.Parse(token, keyFunc)
}

// ExtractClaimsFromToken returns a copy of token's claims.
//
// It returns an empty, non-nil MapClaims when token is nil, or when its claims
// are not jwt.MapClaims (for example, a token parsed with a custom claims
// type). Previously such a token caused a panic.
func ExtractClaimsFromToken(token *jwt.Token) MapClaims {
	if token == nil {
		return make(MapClaims)
	}

	source, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return make(MapClaims)
	}

	claims := make(MapClaims, len(source))
	for key, value := range source {
		claims[key] = value
	}
	return claims
}

// GetToken returns the raw JWT string recorded in the context under
// "JWT_TOKEN". It returns "" when nothing was recorded, or when the stored
// value is not a string. Previously a non-string value caused a panic.
func GetToken(c *gin.Context) string {
	if value, exists := c.Get("JWT_TOKEN"); exists {
		if token, ok := value.(string); ok {
			return token
		}
	}
	return ""
}